spikeinterface motion estimation
motion estimation in spikeinterface¶
In 2021 the spikeinterface project started to implement sortingcomponents, a module of modular building blocks for the spike sorting steps.
Here is an overview of the work-in-progress integration of motion (aka drift) estimation and correction.
This notebook is based on the open dataset published by Nick Steinmetz in 2021, "Imposed motion datasets" from Steinmetz et al., Science 2021: https://figshare.com/articles/dataset/_Imposed_motion_datasets_from_Steinmetz_et_al_Science_2021/14024495
The motion estimation is done in several modular steps:
- detect peaks
- localize peaks:
- "center of mass"
- "monopolar_triangulation" by Julien Boussard and Erdem Varol https://openreview.net/pdf?id=ohfi44BZPC4
- motion estimation:
- rigid or non-rigid
- "decentralized" by Erdem Varol and Julien Boussard DOI: 10.1109/ICASSP39728.2021.9414145
- "motion cloud" by Julien Boussard (not implemented yet)
Here we will show this chain:
- detect peaks > localize peaks with "monopolar_triangulation" > estimate motion with "decentralized"
# Notebook setup. IPython magics: load the autoreload extension and reload
# imported modules before executing code (mode 2), so edits to the
# spikeinterface sources are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
from pathlib import Path
import spikeinterface.full as si
import numpy as np
import matplotlib.pyplot as plt
# Large default figure size for every plot in this notebook.
plt.rcParams["figure.figsize"] = (20, 12)
from probeinterface.plotting import plot_probe
# Modular spike-sorting steps used below: peak detection and peak localization.
from spikeinterface.sortingcomponents import detect_peaks
from spikeinterface.sortingcomponents import localize_peaks
# Local data layout: raw dataset, cached preprocessed traces and peak results
# all live under base_folder (edit this path for your machine).
base_folder = Path('/mnt/data/sam/DataSpikeSorting/imposed_motion_nick')
dataset_folder = base_folder / 'dataset1' / 'NP1'
preprocess_folder = base_folder / 'dataset1_NP1_preprocessed'
peak_folder = base_folder / 'dataset1_NP1_peaks'
peak_folder.mkdir(exist_ok=True)

# Keyword arguments shared by every chunked / parallel processing call below.
job_kwargs = {
    'n_jobs': 40,
    'chunk_memory': '10M',
    'progress_bar': True,
}
# Open the raw SpikeGLX recording.
rec = si.read_spikeglx(dataset_folder)
rec
# Draw the probe layout, restricted to y in [-150, 200].
probe_fig, probe_ax = plt.subplots()
plot_probe(rec.get_probe(), ax=probe_ax)
probe_ax.set_ylim(-150, 200)
preprocess¶
This takes about 4 min for 30 min of signal.
# Preprocessing chain: 300-6000 Hz band-pass, then a global median common reference.
rec_filtered = si.bandpass_filter(rec, freq_min=300., freq_max=6000.)
rec_preprocessed = si.common_reference(rec_filtered, operator='median', reference='global')
# Persist the preprocessed traces to disk (chunked/parallel per job_kwargs).
rec_preprocessed.save(**job_kwargs, folder=preprocess_folder)
# load back: reopen the cached preprocessed recording from disk.
rec_preprocessed = si.load_extractor(preprocess_folder)
rec_preprocessed
# plot and check spikes on 10 channels over a 10 s window.
# Channel ids are taken from the recording being plotted (not from the raw
# `rec`), so the plot does not depend on a second recording object.
si.plot_timeseries(rec_preprocessed, time_range=(100, 110),
                   channel_ids=rec_preprocessed.channel_ids[50:60])
estimate noise¶
# Per-channel noise level in raw (unscaled) units; passed to peak detection below.
noise_levels = si.get_noise_levels(rec_preprocessed, return_scaled=False)
noise_fig, noise_ax = plt.subplots(figsize=(8, 6))
noise_ax.hist(noise_levels, bins=np.arange(0, 10, 1))
noise_ax.set_title('noise across channel')
detect peaks¶
This takes about 1 min 30 s.
from spikeinterface.sortingcomponents import detect_peaks
# Detect negative-going peaks with the 'locally_exclusive' method (100 um radius);
# detect_threshold is presumably relative to the supplied noise_levels.
peaks = detect_peaks(
    rec_preprocessed,
    method='locally_exclusive',
    peak_sign='neg',
    detect_threshold=5,
    local_radius_um=100,
    n_shifts=5,
    noise_levels=noise_levels,
    **job_kwargs,
)
np.save(peak_folder / 'peaks.npy', peaks)
# load back
peaks = np.load(peak_folder / 'peaks.npy')
print(peaks.shape)
from spikeinterface.sortingcomponents import localize_peaks
# Localization #1: center of mass over channels within 100 um,
# using a window of 0.3 ms before / 0.6 ms after each peak.
peak_locations = localize_peaks(
    rec_preprocessed, peaks,
    method='center_of_mass',
    method_kwargs=dict(local_radius_um=100.),
    ms_before=0.3, ms_after=0.6,
    **job_kwargs,
)
np.save(peak_folder / 'peak_locations_center_of_mass.npy', peak_locations)
print(peak_locations.shape)
# Localization #2: monopolar triangulation over channels within 100 um
# (max_distance_um=1000.), same 0.3 ms / 0.6 ms window around each peak.
peak_locations = localize_peaks(
    rec_preprocessed, peaks,
    method='monopolar_triangulation',
    method_kwargs=dict(local_radius_um=100., max_distance_um=1000.),
    ms_before=0.3, ms_after=0.6,
    **job_kwargs,
)
np.save(peak_folder / 'peak_locations_monopolar_triangulation.npy', peak_locations)
print(peak_locations.shape)
# load back: pick which saved localization drives the rest of the notebook.
# Set to 'center_of_mass' or 'monopolar_triangulation' (both files are written
# by the localization cells); this replaces toggling commented-out lines.
localization_method = 'monopolar_triangulation'
peak_locations = np.load(peak_folder / f'peak_locations_{localization_method}.npy')
print(peak_locations)
plot peaks on probe¶
# Overlay every localized peak on the probe layout; the very low alpha makes
# dense regions show up as darker areas.
probe = rec_preprocessed.get_probe()
fig, ax = plt.subplots(figsize=(15, 10))
plot_probe(probe, ax=ax)
ax.scatter(peak_locations['x'], peak_locations['y'], s=1, alpha=0.002, color='k')
ax.set_ylim(1500, 2500)  # alternative zoom: ax.set_ylim(2400, 2900)
plot peak depth vs time¶
# Drift raster: one dot per peak at (time in seconds, estimated y position).
fig, ax = plt.subplots()
x = peaks['sample_ind'] / rec_preprocessed.get_sampling_frequency()
y = peak_locations['y']
ax.scatter(x, y, color='k', alpha=0.05, s=1)
ax.set_ylim(1300, 2500)
motion estimation: rigid with decentralized¶
from spikeinterface.sortingcomponents import (
    estimate_motion,
    make_motion_histogram,
    compute_pairwise_displacement,
    compute_global_displacement,
)
# Histogram of peak positions along y: 2 um spatial bins, 5 s temporal bins.
bin_um = 2
bin_duration_s = 5.
motion_histogram, temporal_bins, spatial_bins = make_motion_histogram(
    rec_preprocessed, peaks,
    peak_locations=peak_locations,
    direction='y',
    bin_um=bin_um, bin_duration_s=bin_duration_s,
    weight_with_amplitude=False,
)
print(motion_histogram.shape, temporal_bins.size, spatial_bins.size)

fig, ax = plt.subplots()
extent = (temporal_bins[0], temporal_bins[-1], spatial_bins[0], spatial_bins[-1])
# Transposed so that time runs along x and depth along y (matching extent).
im = ax.imshow(motion_histogram.T, extent=extent, origin='lower',
               aspect='auto', interpolation='nearest')
im.set_clim(0, 15)
ax.set_ylim(1300, 2500)
ax.set_xlabel('time[s]')
ax.set_ylabel('depth[um]')
pairwise displacement from the motion histogram¶
# Displacement between every pair of temporal bins, using the 'conv2d' method.
pairwise_displacement = compute_pairwise_displacement(motion_histogram, bin_um, method='conv2d')
np.save(peak_folder / 'pairwise_displacement_conv2d.npy', pairwise_displacement)

fig, ax = plt.subplots()
# Both axes of the pairwise matrix are time bins.
extent = (temporal_bins[0], temporal_bins[-1], temporal_bins[0], temporal_bins[-1])
# Diverging colormap with symmetric limits so zero displacement sits mid-scale.
im = ax.imshow(pairwise_displacement, extent=extent, origin='lower', aspect='auto',
               interpolation='nearest', cmap='PiYG')
im.set_clim(-40, 40)
ax.set_aspect('equal')
fig.colorbar(im)
estimate motion (rigid) from the pairwise displacement¶
# Collapse the pairwise matrix into a single (rigid) displacement trace.
motion = compute_global_displacement(pairwise_displacement)
fig, ax = plt.subplots()
# temporal_bins has one more entry than the trace (presumably bin edges),
# so the last value is dropped to align x with motion.
ax.plot(temporal_bins[:-1], motion)
motion estimation with a single function¶
Internally, estimate_motion() does:
- make_motion_histogram()
- compute_pairwise_displacement()
- compute_global_displacement()
# Rigid motion in one call (non_rigid_kwargs=None): a single displacement
# trace for the whole probe, with 5 s temporal bins and 10 um spatial bins.
motion, temporal_bins, spatial_bins = estimate_motion(
    rec_preprocessed, peaks, peak_locations=peak_locations,
    direction='y', bin_duration_s=5., bin_um=10.,
    method='decentralized_registration', method_kwargs={},
    non_rigid_kwargs=None,
    progress_bar=True, verbose=True,
)

# Overlay the motion estimate on the peak-depth raster. The trace is presumably
# centred near 0, so it is shifted to a depth inside the plotted y range.
motion_plot_offset_um = 2000
fig, ax = plt.subplots()
x = peaks['sample_ind'] / rec_preprocessed.get_sampling_frequency()
y = peak_locations['y']
ax.scatter(x, y, s=1, color='k', alpha=0.05)
ax.set_ylim(1300, 2500)
ax.plot(temporal_bins[:-1], motion + motion_plot_offset_um, color='r')
ax.set_xlabel('time [s]')
# The y axis shows peak depth (the red trace is motion + offset), not raw motion.
ax.set_ylabel('depth [um]')
motion estimation: non-rigid¶
# Non-rigid motion: one displacement trace per spatial window
# (window spacing presumably controlled by bin_step_um=200).
motion, temporal_bins, spatial_bins = estimate_motion(
    rec_preprocessed, peaks, peak_locations=peak_locations,
    direction='y', bin_duration_s=5., bin_um=10.,
    method='decentralized_registration', method_kwargs={},
    non_rigid_kwargs=dict(bin_step_um=200),
    progress_bar=True, verbose=True,
)
print(motion.shape)
print(temporal_bins.shape)

fs = rec_preprocessed.get_sampling_frequency()
fig, ax = plt.subplots()
ax.scatter(peaks['sample_ind'] / fs, peak_locations['y'], color='k', s=0.1, alpha=0.05)
# One red trace per window, offset by that window's spatial_bins value
# (presumably the window depth) so it is drawn where it applies.
for i, bin_depth in enumerate(spatial_bins):
    ax.plot(temporal_bins[:-1], motion[:, i] + bin_depth, color='r')
ax.set_ylim(1300, 2500)
ax.set_xlabel('time [s]')
# The y axis shows peak depth (traces are motion + window depth), not raw motion.
ax.set_ylabel('depth [um]')